{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "834061f0",
   "metadata": {},
   "source": [
    "# Lesson 3: Connecting to a SQL Database"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34cb8641",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f50bdee1-25c0-472e-a739-c8551b5c5349",
   "metadata": {
    "height": 132
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from IPython.display import Markdown, HTML, display\n",
    "from langchain.agents import create_sql_agent\n",
    "from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
    "from langchain.sql_database import SQLDatabase\n",
    "from langchain_openai import AzureChatOpenAI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "730f966e",
   "metadata": {},
   "source": [
    "## Load the original dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95771d97",
   "metadata": {},
   "source": [
    "**Note**: To download the data and run this notebook locally, use the following code:\n",
    "\n",
    "```\n",
    "import pandas as pd\n",
    "\n",
    "os.makedirs(\"data\", exist_ok=True)\n",
    "!wget https://covidtracking.com/data/download/all-states-history.csv -P ./data/\n",
    "file_url = \"./data/all-states-history.csv\"\n",
    "df = pd.read_csv(file_url).fillna(value=0)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37246b14-8a2a-4396-b592-16c1b8b62973",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "from sqlalchemy import create_engine\n",
    "import pandas as pd\n",
    "\n",
    "# Load the CSV and replace missing values (NaN) with 0\n",
    "df = pd.read_csv(\"./data/all-states-history.csv\").fillna(value=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8df5c4c1",
   "metadata": {},
   "source": [
    "## Move the data to the SQL database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb39f31-4ab3-4c95-a55d-4140fd68711f",
   "metadata": {
    "height": 251
   },
   "outputs": [],
   "source": [
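    "# Path to your SQLite database file\n",
    "database_file_path = \"./db/test.db\"\n",
    "\n",
    "# SQLite creates the database file on first connection, but not its parent folder\n",
    "os.makedirs(os.path.dirname(database_file_path), exist_ok=True)\n",
    "\n",
    "# Create an engine to connect to the SQLite database\n",
    "# SQLite only requires the path to the database file\n",
    "engine = create_engine(f'sqlite:///{database_file_path}')\n",
    "\n",
    "# Write the DataFrame loaded above to the all_states_history table,\n",
    "# replacing the table if it already exists\n",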
    "df.to_sql(\n",
    "    'all_states_history',\n",
    "    con=engine,\n",
    "    if_exists='replace',\n",
    "    index=False\n",
    ")"
   ]
  },
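  {
   "cell_type": "markdown",
   "id": "c7a1e5d2",
   "metadata": {},
   "source": [
    "**Note**: To confirm that `to_sql` wrote what you expect, you can read the table back and compare it with the DataFrame. The sketch below is a minimal, self-contained illustration that uses an in-memory SQLite database and a tiny hypothetical DataFrame (`demo` and its values are made up, not taken from the dataset). With the real data, you would query `all_states_history` through `engine` instead.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical rows that mimic a few columns of all-states-history.csv\n",
    "demo = pd.DataFrame(\n",
    "    {\n",
    "        \"date\": [\"2020-10-01\", \"2020-10-02\", \"2020-10-01\"],\n",
    "        \"state\": [\"NY\", \"NY\", \"TX\"],\n",
    "        \"hospitalizedIncrease\": [None, 120.0, 80.0],\n",
    "    }\n",
    ").fillna(value=0)  # same NaN handling as the notebook\n",
    "\n",
    "# An in-memory SQLite connection stands in for the ./db/test.db engine\n",
    "conn = sqlite3.connect(\":memory:\")\n",
    "demo.to_sql(\"all_states_history\", con=conn, if_exists=\"replace\", index=False)\n",
    "\n",
    "# Read the table back and compare it with the original DataFrame\n",
    "round_trip = pd.read_sql(\"SELECT * FROM all_states_history\", conn)\n",
    "print(len(round_trip) == len(demo))  # True\n",
    "print(round_trip.columns.tolist())  # ['date', 'state', 'hospitalizedIncrease']\n",
    "```"
   ]
  },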
  {
   "cell_type": "markdown",
   "id": "8307ac30",
   "metadata": {},
   "source": [
    "## Prepare the SQL prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2e02e43-c499-4897-9981-78e7fbae3491",
   "metadata": {
    "height": 608
   },
   "outputs": [],
   "source": [
    "MSSQL_AGENT_PREFIX = \"\"\"\n",
    "\n",
    "You are an agent designed to interact with a SQL database.\n",
    "## Instructions:\n",
    "- Given an input question, create a syntactically correct {dialect} query\n",
    "to run, then look at the results of the query and return the answer.\n",
    "- Unless the user specifies a specific number of examples they wish to\n",
    "obtain, **ALWAYS** limit your query to at most {top_k} results.\n",
    "- You can order the results by a relevant column to return the most\n",
    "interesting examples in the database.\n",
    "- Never query for all the columns from a specific table, only ask for\n",
    "the relevant columns given the question.\n",
    "- You have access to tools for interacting with the database.\n",
    "- You MUST double check your query before executing it. If you get an error\n",
    "while executing a query, rewrite the query and try again.\n",
    "- DO NOT make any DML or DDL statements (INSERT, UPDATE, DELETE, DROP, etc.)\n",
    "to the database.\n",
    "- DO NOT MAKE UP AN ANSWER OR USE PRIOR KNOWLEDGE, ONLY USE THE RESULTS\n",
    "OF THE CALCULATIONS YOU HAVE DONE.\n",
    "- Your response should be in Markdown. However, **when running a SQL query\n",
    "in \"Action Input\", do not include the markdown backticks**.\n",
    "Those are only for formatting the response, not for executing the command.\n",
    "- ALWAYS, as part of your final answer, explain how you got to the answer\n",
    "in a section that starts with: \"Explanation:\". Include the SQL query as\n",
    "part of the explanation section.\n",
    "- If the question does not seem related to the database, just return\n",
    "\"I don\\'t know\" as the answer.\n",
    "- Only use the below tools. Only use the information returned by the\n",
    "below tools to construct your query and final answer.\n",
    "- Do not make up table names, only use the tables returned by any of the\n",
    "tools below.\n",
    "\n",
    "## Tools:\n",
    "\n",
    "\"\"\""
   ]
  },
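  {
   "cell_type": "markdown",
   "id": "d4b9f031",
   "metadata": {},
   "source": [
    "**Note**: `{dialect}` and `{top_k}` in the prefix are placeholders rather than literal text. As we understand it, `create_sql_agent` fills them with Python's `str.format`, using the toolkit's SQL dialect (`sqlite` here) and the `top_k` argument passed later. The sketch below shows that substitution on a shortened, hypothetical stand-in for `MSSQL_AGENT_PREFIX`; it also shows that a literal brace in such a template has to be doubled.\n",
    "\n",
    "```python\n",
    "# A shortened, hypothetical stand-in for MSSQL_AGENT_PREFIX\n",
    "short_prefix = (\n",
    "    \"Given an input question, create a syntactically correct {dialect} query.\\n\"\n",
    "    \"**ALWAYS** limit your query to at most {top_k} results.\\n\"\n",
    "    \"Literal braces must be doubled: {{like this}}.\"\n",
    ")\n",
    "\n",
    "# Fill the placeholders the way str.format does\n",
    "filled = short_prefix.format(dialect=\"sqlite\", top_k=30)\n",
    "print(filled)\n",
    "```"
   ]
  },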
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e70cd29a-c369-4419-ae27-e1f9912b5064",
   "metadata": {
    "height": 693
   },
   "outputs": [],
   "source": [
    "MSSQL_AGENT_FORMAT_INSTRUCTIONS = \"\"\"\n",
    "\n",
    "## Use the following format:\n",
    "\n",
    "Question: the input question you must answer.\n",
    "Thought: you should always think about what to do.\n",
    "Action: the action to take, should be one of [{tool_names}].\n",
    "Action Input: the input to the action.\n",
    "Observation: the result of the action.\n",
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n",
    "Thought: I now know the final answer.\n",
    "Final Answer: the final answer to the original input question.\n",
    "\n",
    "Example of Final Answer:\n",
    "<=== Beginning of example\n",
    "\n",
    "Action: sql_db_query\n",
    "Action Input:\n",
    "SELECT TOP (10) [death]\n",
    "FROM covidtracking\n",
    "WHERE state = 'TX' AND date LIKE '2020%'\n",
    "ORDER BY [death] DESC\n",
    "\n",
    "Observation:\n",
    "[(27437.0,), (27088.0,), (26762.0,), (26521.0,), (26472.0,), (26421.0,), (26408.0,)]\n",
    "Thought: I now know the final answer\n",
    "Final Answer: There were 27437 people who died of COVID-19 in Texas in 2020.\n",
    "\n",
    "Explanation:\n",
    "I queried the `covidtracking` table for the `death` column where the state\n",
    "is 'TX' and the date starts with '2020', sorted from highest to lowest.\n",
    "The `death` column holds cumulative totals, so the highest value in the\n",
    "list, 27437, is the number of deaths in Texas by the end of 2020.\n",
    "I used the following query:\n",
    "\n",
    "```sql\n",
    "SELECT TOP (10) [death] FROM covidtracking WHERE state = 'TX' AND date LIKE '2020%' ORDER BY [death] DESC\n",
    "```\n",
    "===> End of Example\n",
    "\n",
    "\"\"\""
   ]
  },
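  {
   "cell_type": "markdown",
   "id": "e2c84a67",
   "metadata": {},
   "source": [
    "**Note**: The few-shot example above is written in T-SQL (`TOP (10)`, `[death]`), while this notebook's database is SQLite, which has no `TOP` clause and uses `LIMIT` instead. The sketch below runs a SQLite version of that query against an in-memory table with three made-up Texas rows (the values are hypothetical). It also shows that because `death` is a cumulative total in the COVID Tracking Project data, `MAX(death)` rather than `SUM(death)` gives the year-end count.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "conn = sqlite3.connect(\":memory:\")\n",
    "conn.execute(\"CREATE TABLE covidtracking (date TEXT, state TEXT, death REAL)\")\n",
    "# Hypothetical cumulative death counts for three days\n",
    "conn.executemany(\n",
    "    \"INSERT INTO covidtracking VALUES (?, ?, ?)\",\n",
    "    [(\"2020-12-29\", \"TX\", 100.0), (\"2020-12-30\", \"TX\", 110.0), (\"2020-12-31\", \"TX\", 125.0)],\n",
    ")\n",
    "\n",
    "# SQLite uses LIMIT where T-SQL uses TOP; ORDER BY makes \"top\" well defined\n",
    "top_rows = conn.execute(\n",
    "    \"SELECT death FROM covidtracking \"\n",
    "    \"WHERE state = 'TX' AND date LIKE '2020%' \"\n",
    "    \"ORDER BY death DESC LIMIT 10\"\n",
    ").fetchall()\n",
    "print(top_rows)  # [(125.0,), (110.0,), (100.0,)]\n",
    "\n",
    "# A cumulative column is summarized with MAX, not SUM\n",
    "year_end = conn.execute(\n",
    "    \"SELECT MAX(death) FROM covidtracking WHERE state = 'TX' AND date LIKE '2020%'\"\n",
    ").fetchone()[0]\n",
    "print(year_end)  # 125.0\n",
    "```"
   ]
  },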
  {
   "cell_type": "markdown",
   "id": "a9350e98",
   "metadata": {},
   "source": [
    "## Call the Azure Chat model and create the SQL agent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2816f00e",
   "metadata": {},
   "source": [
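    "**Note**: The pre-configured cloud resource grants you access to the Azure OpenAI GPT model. The endpoint and key are not hard-coded below: the endpoint is read from the `AZURE_OPENAI_ENDPOINT` environment variable, and `AzureChatOpenAI` picks up the key from `AZURE_OPENAI_API_KEY`. Both are intended for teaching purposes only. Your notebook environment is already set up with the necessary keys, which may differ from those used by the instructor during filming."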
   ]
  },
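  {
   "cell_type": "markdown",
   "id": "f91d3b5c",
   "metadata": {},
   "source": [
    "**Note**: If the next cell fails to authenticate, a quick first check is whether the expected environment variables are set. The helper below is hypothetical (it is not part of LangChain) and assumes the endpoint lives in `AZURE_OPENAI_ENDPOINT` and the key in `AZURE_OPENAI_API_KEY`, which are the variables `AzureChatOpenAI` from `langchain_openai` falls back to when they are not passed explicitly.\n",
    "\n",
    "```python\n",
    "import os\n",
    "from typing import Iterable, List, Mapping\n",
    "\n",
    "\n",
    "def missing_env_vars(\n",
    "    names: Iterable[str] = (\"AZURE_OPENAI_ENDPOINT\", \"AZURE_OPENAI_API_KEY\"),\n",
    "    env: Mapping[str, str] = os.environ,\n",
    ") -> List[str]:\n",
    "    \"\"\"Return the names that are unset or empty in `env`.\"\"\"\n",
    "    return [name for name in names if not env.get(name)]\n",
    "\n",
    "\n",
    "missing = missing_env_vars()\n",
    "if missing:\n",
    "    print(\"Missing environment variables:\", \", \".join(missing))\n",
    "else:\n",
    "    print(\"Azure OpenAI environment variables are set.\")\n",
    "```"
   ]
  },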
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d5ffc3-cea0-48f8-96ce-177bb9cecd92",
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
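    "# The API key is read from the AZURE_OPENAI_API_KEY environment variable\n",
    "llm = AzureChatOpenAI(\n",
    "    openai_api_version=\"2023-05-15\",\n",
    "    azure_deployment=\"gpt-4-1106\",\n",
    "    azure_endpoint=os.getenv(\"AZURE_OPENAI_ENDPOINT\"),\n",
    "    temperature=0,  # minimize randomness in the generated SQL\n",
    "    max_tokens=500\n",
    ")\n",
    "\n",
    "# Wrap the SQLite database and give the agent its SQL tools\n",
    "db = SQLDatabase.from_uri(f'sqlite:///{database_file_path}')\n",
    "toolkit = SQLDatabaseToolkit(db=db, llm=llm)"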
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50e7259a-ee4b-4c8f-a0fa-f747237ad6a4",
   "metadata": {
    "height": 234
   },
   "outputs": [],
   "source": [
    "QUESTION = \"\"\"How many patients were hospitalized during October 2020\n",
    "in New York, and nationwide as the total of all states?\n",
    "Use the hospitalizedIncrease column\n",
    "\"\"\"\n",
    "\n",
    "agent_executor_SQL = create_sql_agent(\n",
    "    prefix=MSSQL_AGENT_PREFIX,\n",
    "    format_instructions=MSSQL_AGENT_FORMAT_INSTRUCTIONS,\n",
    "    llm=llm,\n",
    "    toolkit=toolkit,\n",
    "    top_k=30,\n",
    "    verbose=True\n",
    ")"
   ]
  },
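  {
   "cell_type": "markdown",
   "id": "a6e0d7f4",
   "metadata": {},
   "source": [
    "**Note**: Before trusting the agent's answer, it can help to write the query you would expect it to run and check the result yourself. The sketch below is one plausible query, not the agent's actual output. It assumes dates are stored as `YYYY-MM-DD` text (as in the CSV) and sums `hospitalizedIncrease` for New York and for all rows in October 2020, using a tiny in-memory table with made-up values. With the real data, you would run the same query against `./db/test.db`.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "conn = sqlite3.connect(\":memory:\")\n",
    "conn.execute(\n",
    "    \"CREATE TABLE all_states_history (date TEXT, state TEXT, hospitalizedIncrease REAL)\"\n",
    ")\n",
    "# Hypothetical rows: two in October 2020 and one in November\n",
    "conn.executemany(\n",
    "    \"INSERT INTO all_states_history VALUES (?, ?, ?)\",\n",
    "    [(\"2020-10-01\", \"NY\", 50.0), (\"2020-10-31\", \"TX\", 30.0), (\"2020-11-01\", \"NY\", 99.0)],\n",
    ")\n",
    "\n",
    "query = \"\"\"\n",
    "SELECT\n",
    "    SUM(CASE WHEN state = 'NY' THEN hospitalizedIncrease ELSE 0 END) AS new_york,\n",
    "    SUM(hospitalizedIncrease) AS nationwide\n",
    "FROM all_states_history\n",
    "WHERE date >= '2020-10-01' AND date < '2020-11-01'\n",
    "\"\"\"\n",
    "new_york, nationwide = conn.execute(query).fetchone()\n",
    "print(new_york, nationwide)  # 50.0 80.0\n",
    "```"
   ]
  },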
  {
   "cell_type": "markdown",
   "id": "9fea48e9",
   "metadata": {},
   "source": [
    "## Invoke the SQL agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "953cee48-831b-4b87-98ef-359fafb7492d",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "agent_executor_SQL.invoke(QUESTION)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
